import warnings
import numpy as np
from util import undersampled, dual_sampling_ratio, afosr_sampling, seq_ofsr_sampling

# Suppress every warning for the whole interpreter, not just this module.
# NOTE(review): this also hides NumPy divide-by-zero / invalid-value
# RuntimeWarnings, e.g. from the score ratios computed in AFSR; consider
# narrowing this (np.errstate or a category filter) so real problems surface.
warnings.filterwarnings("ignore")

def EA(alternative_count, design_indices):  # USR Algorithm
    """Return the least-sampled (context, design) pair.

    Looks up ``alternative_count[context, design]`` for every pair listed in
    ``design_indices`` and returns the pair with the smallest count. Ties are
    broken in favour of the pair that appears first (``np.argmin`` semantics).

    Args:
        alternative_count: Sample counts indexed ``[context, design]``.
        design_indices: Sequence of ``(context, design)`` index pairs.

    Returns:
        Tuple ``(next_context, next_design)``.
    """
    pair_counts = np.array(
        [alternative_count[ctx, dsg] for ctx, dsg in design_indices]
    )
    least_sampled = design_indices[np.argmin(pair_counts)]
    chosen_context, chosen_design = least_sampled
    return chosen_context, chosen_design

def AFSR(EXP, alternative_count, n0, d, f, design_indices):  # BCSR Algorithm
    """Choose the next (context, design) pair to sample under the BCSR rule.

    The rule has three phases:

    1. Warm-up: while the total sample count is below ``d * n0``, pick the
       pair in ``design_indices`` with the smallest count.
    2. Catch-up: if ``undersampled(...)`` returns any pairs, pick the one
       among them with the smallest count.
    3. Scoring: for every pair ``(j, i)`` whose design ``i`` is not the
       row-wise argmax of ``f`` in context ``j``, compute
       ``(EXP.variance[j, i] / alternative_count[j, i])
       / (f[j, best_j] - f[j, i]) ** 2``.
       Pairs holding a context's best design are given the largest score.
       The result is drawn uniformly at random among all pairs that share
       the maximum score.

    Args:
        EXP: Experiment object. Only ``EXP.variance`` (indexed
            ``[context, design]``) is read here.
        alternative_count: Array of sample counts indexed
            ``[context, design]``.
        n0: Warm-up parameter. Warm-up lasts while fewer than ``d * n0``
            samples have been taken (presumably the initial per-pair
            sample size; confirm).
        d: Number of entries of ``design_indices`` scored in phase 3
            (presumably ``len(design_indices)``; confirm with callers).
        f: Array indexed ``[context, design]``. Its row-wise argmax is
            treated as the current best design in each context.
        design_indices: Sequence of ``(context, design)`` index pairs.

    Returns:
        Tuple ``(next_context, next_design)``.
    """
    total_sample = np.sum(alternative_count)
    undersample = undersampled(d, total_sample, alternative_count, design_indices)

    if total_sample < d * n0:
        counts = np.array([alternative_count[m, k] for (m, k) in design_indices])
        next_context, next_design = design_indices[np.argmin(counts)]

    # Test emptiness with len(), not np.any(). On a list of (i, j) index
    # pairs, np.any() checks whether any index *value* is non-zero, so a
    # catch-up set containing only (0, 0) would be mistaken for "empty".
    elif len(undersample) > 0:
        extracted_values = np.array([alternative_count[i, j] for (i, j) in undersample])
        next_context, next_design = undersample[np.argmin(extracted_values)]

    else:
        # Current best design in each context (row-wise argmax of f).
        opt_solution = np.argmax(f, axis=1)

        # Build the per-pair variance from the same ``design_indices`` that
        # the scoring loop walks (rather than EXP.design_indices), so
        # variance[h] always belongs to design_indices[h].
        m_indices, k_indices = zip(*design_indices)
        design_var = EXP.variance[m_indices, k_indices]
        design_count = alternative_count[m_indices, k_indices]
        variance = design_var / design_count

        S_score = np.zeros(d)
        for h in range(d):
            j, i = design_indices[h]
            if i != opt_solution[j]:
                # Variance of the sample mean over the squared gap to the
                # context's best design. A zero gap yields +inf (or nan when
                # the variance is also zero).
                S_score[h] = variance[h] / (f[j, opt_solution[j]] - f[j, i]) ** 2
            else:
                S_score[h] = -np.inf  # placeholder for the context's best design

        # Lift the best-design placeholders to the current top score so they
        # join the random tie-break among maximal entries. np.isinf also
        # matches +inf scores from zero gaps; those are already the maximum.
        top_score = np.max(S_score)
        S_score[np.isinf(S_score)] = top_score
        max_val = np.max(S_score)
        max_indices = np.where(S_score == max_val)[0]
        random_max_index = np.random.choice(max_indices)
        next_context, next_design = design_indices[random_max_index]

    return next_context, next_design

def AFOSR(EXP, K, M, d, b, f, g, alternative_count, phi, Temp, n0, design_indices):  # GFSR Algorithm
    """Choose the next (context, design) pair to sample under the GFSR rule.

    1. Warm-up: while the total sample count is below ``d * n0``, pick the
       pair in ``design_indices`` with the smallest count.
    2. Catch-up: if ``undersampled(...)`` returns any pairs, pick the one
       among them with the smallest count.
    3. Otherwise delegate the choice to ``afosr_sampling``.

    Args:
        EXP, K, M, b, f, g, phi, Temp: Forwarded unchanged to
            ``afosr_sampling``. Their meaning is defined there.
        d: Passed to ``undersampled`` and used in the warm-up threshold
            ``d * n0`` (presumably the number of (context, design) pairs;
            confirm).
        alternative_count: Array of sample counts indexed
            ``[context, design]``.
        n0: Warm-up parameter. Warm-up lasts while fewer than ``d * n0``
            samples have been taken.
        design_indices: Sequence of ``(context, design)`` index pairs.

    Returns:
        Tuple ``(next_context, next_design)``.
    """
    total_sample = np.sum(alternative_count)
    undersample = undersampled(d, total_sample, alternative_count, design_indices)

    if total_sample < d * n0:
        counts = np.array([alternative_count[m, k] for (m, k) in design_indices])
        next_context, next_design = design_indices[np.argmin(counts)]
    # len() rather than np.any(): np.any() on a list of (i, j) index pairs
    # tests the index values, so a lone (0, 0) entry would look "empty".
    elif len(undersample) > 0:
        extracted_values = np.array([alternative_count[i, j] for (i, j) in undersample])
        next_context, next_design = undersample[np.argmin(extracted_values)]
    else:
        next_context, next_design = afosr_sampling(
            EXP, K, M, b, f, g, alternative_count, phi, Temp, design_indices
        )
    return next_context, next_design

def SEQ_OFSR(EXP, K, M, d, f, phi, Temp, n0, alternative_count, design_indices):  # GOSR Algorithm
    """Choose the next (context, design) pair to sample under the GOSR rule.

    1. Warm-up: while the total sample count is below ``d * n0``, pick the
       pair in ``design_indices`` with the smallest count.
    2. Catch-up: if ``undersampled(...)`` returns any pairs, pick the one
       among them with the smallest count.
    3. Otherwise delegate the choice to ``seq_ofsr_sampling``.

    Args:
        EXP, K, M, f, phi, Temp: Forwarded unchanged to
            ``seq_ofsr_sampling``. Their meaning is defined there.
        d: Passed to ``undersampled`` and used in the warm-up threshold
            ``d * n0`` (presumably the number of (context, design) pairs;
            confirm).
        n0: Warm-up parameter. Warm-up lasts while fewer than ``d * n0``
            samples have been taken.
        alternative_count: Array of sample counts indexed
            ``[context, design]``.
        design_indices: Sequence of ``(context, design)`` index pairs.

    Returns:
        Tuple ``(next_context, next_design)``.
    """
    total_sample = np.sum(alternative_count)
    undersample = undersampled(d, total_sample, alternative_count, design_indices)

    if total_sample < d * n0:
        counts = np.array([alternative_count[m, k] for (m, k) in design_indices])
        next_context, next_design = design_indices[np.argmin(counts)]
    # len() rather than np.any(): np.any() on a list of (i, j) index pairs
    # tests the index values, so a lone (0, 0) entry would look "empty".
    elif len(undersample) > 0:
        extracted_values = np.array([alternative_count[i, j] for (i, j) in undersample])
        next_context, next_design = undersample[np.argmin(extracted_values)]
    else:
        next_context, next_design = seq_ofsr_sampling(
            EXP, K, M, f, alternative_count, phi, Temp, design_indices
        )
    return next_context, next_design

def DualSR(K, M, d, b, f, g, opt_solution, feasibility, phi, Z_mat, Temp, n0, alternative_count, lambda_, design_indices, design_var, flag):  # DSR Algorithm
    """Choose the next (context, design) pair to sample under the DSR rule.

    1. Warm-up: while the total sample count is below ``d * n0``, pick the
       least-sampled pair and reset ``lambda_`` to a uniform vector of
       length ``K * M``.
    2. Catch-up: if ``undersampled(...)`` returns any pairs, pick the
       least-sampled one among them and set ``lambda_`` to ``None``.
    3. Otherwise call ``dual_sampling_ratio`` to obtain an updated
       ``lambda_`` and a target ratio ``static_ratio``. Then pick the pair
       whose count is furthest below ``total_sample * static_ratio``.

    Args:
        K, M, b, f, g, opt_solution, feasibility, phi, Z_mat, Temp,
        design_var, flag: Forwarded unchanged to ``dual_sampling_ratio``.
            Their meaning is defined there.
        d: Passed to ``undersampled`` and ``dual_sampling_ratio`` and used
            in the warm-up threshold ``d * n0``.
        n0: Warm-up parameter. Warm-up lasts while fewer than ``d * n0``
            samples have been taken.
        alternative_count: Array of sample counts indexed
            ``[context, design]``.
        lambda_: Current dual iterate, passed to ``dual_sampling_ratio``
            in phase 3.
        design_indices: Sequence of ``(context, design)`` index pairs.

    Returns:
        Tuple ``(next_context, next_design, lambda_)`` with ``lambda_``
        updated as described above.
    """
    total_sample = np.sum(alternative_count)
    undersample = undersampled(d, total_sample, alternative_count, design_indices)

    if total_sample < d * n0:
        counts = np.array([alternative_count[m, k] for (m, k) in design_indices])
        next_context, next_design = design_indices[np.argmin(counts)]
        lambda_ = np.ones(K * M) / (K * M)
    # len() rather than np.any(): np.any() on a list of (i, j) index pairs
    # tests the index values, so a lone (0, 0) entry would look "empty".
    elif len(undersample) > 0:
        extracted_values = np.array([alternative_count[i, j] for (i, j) in undersample])
        next_context, next_design = undersample[np.argmin(extracted_values)]
        # NOTE(review): this discards the incoming iterate; presumably
        # dual_sampling_ratio re-initializes from None. Confirm.
        lambda_ = None
    else:
        lambda_, static_ratio = dual_sampling_ratio(total_sample, lambda_, M, K, d, b, f, g, opt_solution, feasibility, phi, Z_mat, Temp, design_var, flag)
        counts = np.array([alternative_count[m, k] for (m, k) in design_indices])
        # Most-behind pair relative to its target share. Assumes static_ratio
        # is aligned element-wise with design_indices; confirm.
        min_idx = np.argmin(counts - total_sample * static_ratio)
        next_context, next_design = design_indices[min_idx]

    return next_context, next_design, lambda_


